# -*- coding: utf-8 -*-
"""
Created on Wed Jun 17 16:29:00 2020

@author: Colleen
"""

from sklearn import tree
from sklearn.metrics import classification_report 
import os
os.environ["PATH"] += os.pathsep + 'F:/graphviz-2.38/release/bin'
from sklearn.externals import joblib
import pydotplus
from sklearn import svm
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from xgboost import XGBClassifier
from xgboost import plot_importance
from sklearn.ensemble import RandomForestClassifier

class module(object):
    """Train several classifiers (decision tree, SVM, XGBoost, random
    forest), compare their accuracy, and persist the chosen/best model
    with joblib.

    NOTE(review): the class name shadows no builtin but is unusually
    generic; kept unchanged because external callers depend on it.
    """

    def count_precision(self, a, b):
        '''
        Compute the element-wise agreement (accuracy) of two sequences.

        input:
            a, b -- equal-length sequences to compare position by position
        output:
            pre -- fraction of positions where a[i] == b[i];
                   0.0 when the sequences are empty (avoids ZeroDivisionError)
        '''
        le = len(a)
        if le == 0:
            # Guard: empty input would otherwise divide by zero.
            return 0.0
        co = sum(1 for x, y in zip(a, b) if x == y)
        return co / le

    def c45(self, train_features, test_features, train_labels, test_labels):
        '''
        Train an entropy-criterion decision tree (C4.5-style), print its
        classification report, and export the tree diagram to c45tree.pdf.

        input:
            train_features, train_labels -- training split
            test_features, test_labels   -- evaluation split
        output:
            c45_pre -- accuracy on the test split
        Side effects:
            stores the fitted model on self.c45_clf; writes c45tree.pdf.
        '''
        # Train the decision-tree model.
        c45_clf = tree.DecisionTreeClassifier(criterion='entropy', random_state=0)
        self.c45_clf = c45_clf.fit(train_features, train_labels)
        # Report per-class precision/recall/F1 on the held-out split.
        y_predict = self.c45_clf.predict(test_features)
        score = classification_report(test_labels, y_predict)
        print("C4.5 Tree:{}".format(score))

        # Render the fitted tree and save it as a PDF (requires graphviz
        # on PATH -- see the os.environ tweak at the top of the file).
        dot_data = tree.export_graphviz(self.c45_clf, out_file=None,
                             filled=True, rounded=True,
                             special_characters=True)
        graph = pydotplus.graph_from_dot_data(dot_data)
        graph.write_pdf("c45tree.pdf")

        c45_pre = self.count_precision(test_labels, y_predict)
        return c45_pre

    def svm(self, train_features, test_features, train_labels, test_labels):
        '''
        Train a linear support vector machine and print its
        classification report.

        output:
            svm_pre -- accuracy on the test split
        Side effects:
            stores the fitted model on self.svm_clf.
        '''
        # NOTE: inside this method the bare name `svm` still resolves to
        # the sklearn module at global scope, not to this method.
        cls = svm.LinearSVC()
        self.svm_clf = cls.fit(train_features, train_labels)
        y_predict = self.svm_clf.predict(test_features)
        score = classification_report(test_labels, y_predict)
        print("SVM:{}".format(score))
        svm_pre = self.count_precision(test_labels, y_predict)
        return svm_pre

    def xgboost(self, train_features, test_features, train_labels, test_labels):
        '''
        Train an XGBoost multi-class classifier with early stopping on the
        test split, print its accuracy, and plot feature importances.

        output:
            accuracy -- accuracy on the test split
        Side effects:
            stores the fitted model on self.xg_clf; shows a matplotlib figure.
        '''
        xg_model = XGBClassifier(learning_rate=0.1,
                      n_estimators=1000,         # up to 1000 boosting rounds
                      max_depth=6,               # tree depth
                      min_child_weight=1,        # minimum leaf weight
                      gamma=0.,                  # min loss reduction to split
                      subsample=0.8,             # row subsampling per tree
                      colsample_bytree=0.8,      # BUGFIX: was misspelled
                                                 # 'colsample_btree' and
                                                 # silently ignored
                      objective='multi:softmax', # multi-class loss
                      scale_pos_weight=1,        # class-imbalance weight
                      random_state=27            # reproducibility
                       )

        # Early stopping monitors mlogloss on the evaluation split; note
        # the same split is reused as the final test set, which is an
        # optimistic protocol -- kept as-is to preserve behavior.
        self.xg_clf = xg_model.fit(train_features, train_labels, eval_set = [(test_features, test_labels)], eval_metric = "mlogloss",
                     early_stopping_rounds = 10, verbose = True)
        y_pred = self.xg_clf.predict(test_features)
        accuracy = accuracy_score(y_pred, test_labels)
        print("Xgboost:{}".format(accuracy))

        fig, ax = plt.subplots(figsize=(15, 15))
        plot_importance(xg_model, height=0.5, ax=ax, max_num_features=64)
        plt.show()

        return accuracy

    def randomforest(self, train_features, test_features, train_labels, test_labels):
        '''
        Train a random forest and print its classification report.

        output:
            rf_pre -- accuracy on the test split
        Side effects:
            stores the fitted model on self.rm_clf.
        '''
        rfc = RandomForestClassifier(random_state=0)
        self.rm_clf = rfc.fit(train_features, train_labels)
        y_predict2 = self.rm_clf.predict(test_features)
        score_r = classification_report(test_labels, y_predict2)
        print("Randomforest:{}".format(score_r))
        rf_pre = self.count_precision(test_labels, y_predict2)
        return rf_pre

    def run(self, features, labels, model, model_name):
        '''
        Split the data 80/20, train the requested model(s), and save the
        fitted (or best, for 'all') model to <model_name>.m via joblib.

        input:
            features   -- feature matrix
            labels     -- target labels
            model      -- one of 'svm', 'c45', 'xgb', 'rf', 'all'
                          (NOTE(review): docstring used to also list 'km',
                          but no such branch exists)
            model_name -- output filename stem, '.m' suffix is appended
        output:
            None; the chosen model is written to disk.
        '''
        train_features, test_features, train_labels, test_labels = train_test_split(features, labels, test_size=0.2, random_state=0)
        if model == 'svm':
            svm_pre = self.svm(train_features, test_features, train_labels, test_labels)
            max_pre = svm_pre
            joblib.dump(self.svm_clf, model_name + '.m')  # persist model
        if model == 'c45':
            c45_pre = self.c45(train_features, test_features, train_labels, test_labels)
            max_pre = c45_pre
            joblib.dump(self.c45_clf, model_name + '.m')  # persist model
        if model == 'xgb':
            xgb_pre = self.xgboost(train_features, test_features, train_labels, test_labels)
            max_pre = xgb_pre
            joblib.dump(self.xg_clf, model_name + '.m')  # persist model
        if model == 'rf':
            rf_pre = self.randomforest(train_features, test_features, train_labels, test_labels)
            max_pre = rf_pre
            # BUGFIX: was self.rf_clf, an attribute that is never set
            # (randomforest() stores the model as self.rm_clf), which
            # raised AttributeError here.
            joblib.dump(self.rm_clf, model_name + '.m')  # persist model
        if model == 'all':
            # Train every model, then persist whichever scored highest.
            result = {}
            result['svm_precision'] = self.svm(train_features, test_features, train_labels, test_labels)
            result['c45tree_precision'] = self.c45(train_features, test_features, train_labels, test_labels)
            result['xgboost_precision'] = self.xgboost(train_features, test_features, train_labels, test_labels)
            result['randomforest_precision'] = self.randomforest(train_features, test_features, train_labels, test_labels)
            print(result)
            max_model = max(result.items(), key=lambda x: x[1])
            print(max_model, max_model[0])
            if max_model[0] == 'svm_precision': joblib.dump(self.svm_clf, model_name + '.m')  # persist model
            if max_model[0] == 'c45tree_precision': joblib.dump(self.c45_clf, model_name + '.m')  # persist model
            if max_model[0] == 'xgboost_precision': joblib.dump(self.xg_clf, model_name + '.m')  # persist model
            # BUGFIX: was self.rf_clf (never set) -- see 'rf' branch above.
            if max_model[0] == 'randomforest_precision': joblib.dump(self.rm_clf, model_name + '.m')  # persist model

    def test_tree(self, test_features, model_name):
        '''
        Load a previously saved model and run inference with it.

        input:
            test_features -- feature matrix to predict on
            model_name    -- filename stem of the saved model (no '.m' suffix)
        output:
            y_predict -- predicted labels for test_features
        '''
        clf = joblib.load(model_name + '.m')
        y_predict = clf.predict(test_features)
        return y_predict

        